{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "from torch.utils.data import Dataset,DataLoader\n",
    "import torch.optim as optim\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "from tqdm.notebook import tqdm\n",
    "from tqdm import tqdm\n",
    "\n",
    "class DictObj(object):\n",
    "    # 私有变量是map\n",
    "    # 设置变量的时候 初始化设置map\n",
    "    def __init__(self, mp):\n",
    "        self.map = mp\n",
    "        # print(mp)\n",
    "\n",
    "# set 可以省略 如果直接初始化设置\n",
    "    def __setattr__(self, name, value):\n",
    "        if name == 'map':# 初始化的设置 走默认的方法\n",
    "            # print(\"init set attr\", name ,\"value:\", value)\n",
    "            object.__setattr__(self, name, value)\n",
    "            return\n",
    "        # print('set attr called ', name, value)\n",
    "        self.map[name] = value\n",
    "# 之所以自己新建一个类就是为了能够实现直接调用名字的功能。\n",
    "    def __getattr__(self, name):\n",
    "        # print('get attr called ', name)\n",
    "        return  self.map[name]\n",
    "\n",
    "\n",
    "Config = DictObj({\n",
    "    'poem_path' : \"./tang.npz\",\n",
    "    'model_save_path':'./poem.pth',\n",
    "    'embedding_dim':100,\n",
    "    'hidden_dim':1024,\n",
    "    'lr':0.001,\n",
    "    'LSTM_layers':3\n",
    "})"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class PoemDataSet(Dataset):\n",
    "    def __init__(self,poem_path,seq_len):\n",
    "        self.seq_len = seq_len\n",
    "        self.poem_path = poem_path\n",
    "        self.poem_data, self.ix2word, self.word2ix = self.get_raw_data()\n",
    "        self.no_space_data = self.filter_space()\n",
    "\n",
    "    def __getitem__(self, idx:int):\n",
    "        txt = self.no_space_data[idx*self.seq_len : (idx+1)*self.seq_len]\n",
    "        label = self.no_space_data[idx*self.seq_len + 1 : (idx+1)*self.seq_len + 1] # 将窗口向后移动一个字符就是标签\n",
    "        txt = torch.from_numpy(np.array(txt)).long()\n",
    "        label = torch.from_numpy(np.array(label)).long()\n",
    "        return txt,label\n",
    "\n",
    "    def __len__(self):\n",
    "        return int(len(self.no_space_data) / self.seq_len)\n",
    "\n",
    "    def filter_space(self): # 将空格的数据给过滤掉，并将原始数据平整到一维\n",
    "        t_data = torch.from_numpy(self.poem_data).view(-1)\n",
    "        flat_data = t_data.numpy()\n",
    "        no_space_data = []\n",
    "        for i in flat_data:\n",
    "            if (i != 8292 ):\n",
    "                no_space_data.append(i)\n",
    "        return no_space_data\n",
    "    def get_raw_data(self):\n",
    "        datas = np.load(self.poem_path,allow_pickle=True)  #numpy 1.16.2  以上引入了allow_pickle\n",
    "        data = datas['data']\n",
    "        ix2word = datas['ix2word'].item()\n",
    "        word2ix = datas['word2ix'].item()\n",
    "        return data, ix2word, word2ix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "(57580, 125)\n[['<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<'\n  '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<'\n  '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<'\n  '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<'\n  '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<'\n  '<' '<' '知' '尔' '新' '从' '海' '外' '来' '，' '晓' '窗' '吟' '坐' '思' '难' '裁' '。'\n  '堪' '怜' '时' '复' '撼' '书' '幌' '，' '似' '报' '故' '园' '花' '欲' '开' '。' '<']]\n"
    }
   ],
   "source": [
    "def view_data(poem_path):\n",
    "    datas = np.load(poem_path, allow_pickle=True)\n",
    "    data = datas['data']\n",
    "    ix2word = datas['ix2word'].item()\n",
    "    word2ix = datas['word2ix'].item()\n",
    "    word_data = np.zeros((1,data.shape[1]),dtype=np.str) # 这样初始化后值会保留第一一个字符，所以输出中'<START>' 变成了'<'\n",
    "    row = np.random.randint(data.shape[0])\n",
    "    for col in range(data.shape[1]):\n",
    "        word_data[0,col] = ix2word[data[row,col]]\n",
    "    print(data.shape) #(57580, 125)\n",
    "    print(word_data)#随机查看\n",
    "view_data(Config.poem_path)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    },
    "tags": []
   }
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "outputs": [],
   "source": [
    "class PoemDataSet(Dataset):\n",
    "    def __init__(self,poem_path,seq_len):\n",
    "        self.seq_len = seq_len\n",
    "        self.poem_path = poem_path\n",
    "        self.poem_data, self.ix2word, self.word2ix = self.get_raw_data()\n",
    "        self.no_space_data = self.filter_space()\n",
    "\n",
    "    def __getitem__(self, idx:int):\n",
    "        txt = self.no_space_data[idx*self.seq_len : (idx+1)*self.seq_len]\n",
    "        label = self.no_space_data[idx*self.seq_len + 1 : (idx+1)*self.seq_len + 1] # 将窗口向后移动一个字符就是标签\n",
    "        txt = torch.from_numpy(np.array(txt)).long()\n",
    "        label = torch.from_numpy(np.array(label)).long()\n",
    "        return txt,label\n",
    "\n",
    "    def __len__(self):\n",
    "        return int(len(self.no_space_data) / self.seq_len)\n",
    "\n",
    "    def filter_space(self): # 将空格的数据给过滤掉，并将原始数据平整到一维\n",
    "        t_data = torch.from_numpy(self.poem_data).view(-1)\n",
    "        flat_data = t_data.numpy()\n",
    "        no_space_data = []\n",
    "        for i in flat_data:\n",
    "            if (i != 8292 ):\n",
    "                no_space_data.append(i)\n",
    "        return no_space_data\n",
    "    def get_raw_data(self):\n",
    "        datas = np.load(self.poem_path,allow_pickle=True)\n",
    "        data = datas['data']\n",
    "        ix2word = datas['ix2word'].item()\n",
    "        word2ix = datas['word2ix'].item()\n",
    "        return data, ix2word, word2ix"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "(tensor([8291, 6731, 4770, 1787, 8118, 7577, 7066, 4817,  648, 7121, 1542, 6483,\n         7435, 7686, 2889, 1671, 5862, 1949, 7066, 2596, 4785, 3629, 1379, 2703,\n         7435, 6064, 6041, 4666, 4038, 4881, 7066, 4747, 1534,   70, 3788, 3823,\n         7435, 4907, 5567,  201, 2834, 1519, 7066,  782,  782, 2063, 2031,  846]),\n tensor([6731, 4770, 1787, 8118, 7577, 7066, 4817,  648, 7121, 1542, 6483, 7435,\n         7686, 2889, 1671, 5862, 1949, 7066, 2596, 4785, 3629, 1379, 2703, 7435,\n         6064, 6041, 4666, 4038, 4881, 7066, 4747, 1534,   70, 3788, 3823, 7435,\n         4907, 5567,  201, 2834, 1519, 7066,  782,  782, 2063, 2031,  846, 7435]))"
     },
     "metadata": {},
     "execution_count": 21
    }
   ],
   "source": [
    "poem_ds = PoemDataSet(Config.poem_path, 48)\n",
    "ix2word = poem_ds.ix2word\n",
    "word2ix = poem_ds.word2ix\n",
    "poem_ds[0]"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "<torch.utils.data.dataloader.DataLoader at 0x1c303a6c748>"
     },
     "metadata": {},
     "execution_count": 22
    }
   ],
   "source": [
    "poem_loader =  DataLoader(poem_ds,\n",
    "                     batch_size=16,\n",
    "                     shuffle=True,\n",
    "                     num_workers=0)\n",
    "poem_loader"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "outputs": [],
   "source": [
    "class PoemModel(nn.Module):\n",
    "    def __init__(self, vocab_size, embedding_dim, hidden_dim):\n",
    "        super(PoemModel, self).__init__()\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.embeddings = nn.Embedding(vocab_size, embedding_dim)#vocab_size:就是ix2word这个字典的长度。\n",
    "        self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, num_layers=Config.LSTM_layers,\n",
    "                            batch_first=True,dropout=0, bidirectional=False)\n",
    "        self.fc1 = nn.Linear(self.hidden_dim,2048)\n",
    "        self.fc2 = nn.Linear(2048,4096)\n",
    "        self.fc3 = nn.Linear(4096,vocab_size)\n",
    "#         self.linear = nn.Linear(self.hidden_dim, vocab_size)# 输出的大小是词表的维度，\n",
    "\n",
    "    def forward(self, input, hidden=None):\n",
    "        embeds = self.embeddings(input)  # [batch, seq_len] => [batch, seq_len, embed_dim]\n",
    "        batch_size, seq_len = input.size()\n",
    "        if hidden is None:\n",
    "            h_0 = input.data.new(Config.LSTM_layers*1, batch_size, self.hidden_dim).fill_(0).float()\n",
    "            c_0 = input.data.new(Config.LSTM_layers*1, batch_size, self.hidden_dim).fill_(0).float()\n",
    "        else:\n",
    "            h_0, c_0 = hidden\n",
    "        output, hidden = self.lstm(embeds, (h_0, c_0))#hidden 是h,和c 这两个隐状态\n",
    "        output = torch.tanh(self.fc1(output))\n",
    "        output = torch.tanh(self.fc2(output))\n",
    "        output = self.fc3(output)\n",
    "        output = output.reshape(batch_size * seq_len, -1)\n",
    "        return output,hidden"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "outputs": [],
   "source": [
    "class AvgrageMeter(object):\n",
    "\n",
    "    def __init__(self):\n",
    "        self.reset()\n",
    "\n",
    "    def reset(self):\n",
    "        self.avg = 0\n",
    "        self.sum = 0\n",
    "        self.cnt = 0\n",
    "\n",
    "    def update(self, val, n=1):\n",
    "        self.sum += val * n\n",
    "        self.cnt += n\n",
    "        self.avg = self.sum / self.cnt\n",
    "## topk的准确率计算\n",
    "def accuracy(output, label, topk=(1,)):\n",
    "    maxk = max(topk)\n",
    "    batch_size = label.size(0)\n",
    "\n",
    "    # 获取前K的索引\n",
    "    _, pred = output.topk(maxk, 1, True, True) #使用topk来获得前k个的索引\n",
    "    pred = pred.t() # 进行转置\n",
    "    # eq按照对应元素进行比较 view(1,-1) 自动转换到行为1,的形状， expand_as(pred) 扩展到pred的shape\n",
    "    # expand_as 执行按行复制来扩展，要保证列相等\n",
    "    correct = pred.eq(label.view(1, -1).expand_as(pred)) # 与正确标签序列形成的矩阵相比，生成True/False矩阵\n",
    "#     print(correct)\n",
    "\n",
    "    rtn = []\n",
    "    for k in topk:\n",
    "        correct_k = correct[:k].view(-1).float().sum(0) # 前k行的数据 然后平整到1维度，来计算true的总个数\n",
    "        rtn.append(correct_k.mul_(100.0 / batch_size)) # mul_() ternsor 的乘法  正确的数目/总的数目 乘以100 变成百分比\n",
    "    return rtn"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "def train( epochs, train_loader, device, model, criterion, optimizer,scheduler):\n",
    "    model.train()\n",
    "    top1 = AvgrageMeter()\n",
    "    model = model.to(device)\n",
    "    for epoch in range(epochs):\n",
    "        train_loss = 0.0\n",
    "        train_loader = tqdm(train_loader)\n",
    "        train_loader.set_description('[%s%04d/%04d %s%f]' % ('Epoch:', epoch + 1, epochs, 'lr:', scheduler.get_lr()[0]))\n",
    "        for i, data in enumerate(train_loader, 0):  # 0是下标起始位置默认为0\n",
    "            inputs, labels = data[0].to(device), data[1].to(device)\n",
    "#             print(' '.join(ix2word[inputs.view(-1)[k] for k in inputs.view(-1).shape.item()]))\n",
    "            labels = labels.view(-1) # 因为outputs经过平整，所以labels也要平整来对齐\n",
    "            # 初始为0，清除上个batch的梯度信息\n",
    "            optimizer.zero_grad()\n",
    "            outputs,hidden = model(inputs)\n",
    "            loss = criterion(outputs,labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            _,pred = outputs.topk(1)\n",
    "#             print(get_word(pred))\n",
    "#             print(get_word(labels))\n",
    "            prec1, prec2= accuracy(outputs, labels, topk=(1,2))\n",
    "            n = inputs.size(0)\n",
    "            top1.update(prec1.item(), n)\n",
    "            train_loss += loss.item()\n",
    "            postfix = {'train_loss': '%.6f' % (train_loss / (i + 1)), 'train_acc': '%.6f' % top1.avg}\n",
    "            train_loader.set_postfix(log=postfix)\n",
    "        scheduler.step()\n",
    "        torch.save(model.state_dict(), Config.model_save_path)\n",
    "\n",
    "    print('Finished Training')"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "PoemModel(\n  (embeddings): Embedding(8293, 100)\n  (lstm): LSTM(100, 1024, num_layers=3, batch_first=True)\n  (fc1): Linear(in_features=1024, out_features=2048, bias=True)\n  (fc2): Linear(in_features=2048, out_features=4096, bias=True)\n  (fc3): Linear(in_features=4096, out_features=8293, bias=True)\n)"
     },
     "metadata": {},
     "execution_count": 26
    }
   ],
   "source": [
    "model = PoemModel(len(word2ix),\n",
    "                  embedding_dim=Config.embedding_dim,\n",
    "                  hidden_dim=Config.hidden_dim)\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "if os.path.exists(Config.model_save_path):\n",
    "    model.load_state_dict(torch.load(Config.model_save_path))\n",
    "\n",
    "epochs = 30\n",
    "optimizer = optim.Adam(model.parameters(), lr=Config.lr)\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 10,gamma=0.001)#学习率调整\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "model"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": "[Epoch:0001/0030 lr:0.001000]:   0%|          | 13/4079 [00:06<30:10,  2.25it/s, log={'train_loss': '8.997909', 'train_acc': '3.425481'}]"
    },
    {
     "output_type": "error",
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-27-9e728ef202fc>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mtrain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mepochs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpoem_loader\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mscheduler\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32m<ipython-input-25-4f2914b9adad>\u001b[0m in \u001b[0;36mtrain\u001b[1;34m(epochs, train_loader, device, model, criterion, optimizer, scheduler)\u001b[0m\n\u001b[0;32m     24\u001b[0m             \u001b[0mprec1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mprec2\u001b[0m\u001b[1;33m=\u001b[0m \u001b[0maccuracy\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtopk\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     25\u001b[0m             \u001b[0mn\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 26\u001b[1;33m             \u001b[0mtop1\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mprec1\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mn\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     27\u001b[0m             \u001b[0mtrain_loss\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[0mloss\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     28\u001b[0m             \u001b[0mpostfix\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m{\u001b[0m\u001b[1;34m'train_loss'\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;34m'%.6f'\u001b[0m \u001b[1;33m%\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mtrain_loss\u001b[0m \u001b[1;33m/\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mi\u001b[0m \u001b[1;33m+\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'train_acc'\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;34m'%.6f'\u001b[0m \u001b[1;33m%\u001b[0m \u001b[0mtop1\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mavg\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "train(epochs, poem_loader, device, model, criterion, optimizer,scheduler)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": true
    },
    "tags": []
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "3.7.6-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}